{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4KwjZCTXRPOh"
   },
   "source": [
    "# Dolly vs. Pythia Comparison\n",
    "\n",
    "This notebook uses Phoenix to visualize the embeddings of prompt-response pairs generated with Dolly and Pythia."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install arize-phoenix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "k7bWKFunRc2I"
   },
   "source": [
    "Import libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import hashlib\n",
    "import re\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import phoenix as px\n",
    "\n",
    "pd.set_option(\"display.max_colwidth\", None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pythia_file = \"https://storage.googleapis.com/arize-assets/fixtures/Embeddings/GENERATIVE/pythia-2.8b_2023-05-27_16-54-20.csv\"\n",
    "pythia_file = \"https://storage.googleapis.com/arize-assets/fixtures/Embeddings/GENERATIVE/pythia-2.8b_2023-06-01_02-51-40.csv\"\n",
    "# pythia_file = \"https://storage.googleapis.com/arize-assets/fixtures/Embeddings/GENERATIVE/pythia-2.8b-deduped_2023-06-01_06-37-21.csv\"\n",
    "# pythia_file = \"https://storage.googleapis.com/arize-assets/fixtures/Embeddings/GENERATIVE/pythia-2.8b_2023-06-01_06-35-09.csv\"\n",
    "dolly_file = \"https://storage.googleapis.com/arize-assets/fixtures/Embeddings/GENERATIVE/dolly-v2-3b_2023-06-01_02-12-20.csv\"\n",
    "# dolly_file = \"https://storage.googleapis.com/arize-assets/fixtures/Embeddings/GENERATIVE/dolly-v2-3b_2023-05-27_18-23-02.csv\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0JXVUnRNWH5O"
   },
   "source": [
    "Download the data and parse the stringified embedding columns into NumPy arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def string_to_array(s):\n",
    "    # Parse a stringified vector such as \"[0.1 -2.3e-04]\" into a float array.\n",
    "    numbers = re.findall(r\"[-+]?\\d*\\.\\d+e[+-]?\\d+|[-+]?\\d+\\.\\d*|[-+]?\\d+\", s)\n",
    "    return np.array([float(num) for num in numbers])\n",
    "\n",
    "\n",
    "def load_dataframe(file):\n",
    "    # Read the CSV and replace the stringified *_vec columns with array columns.\n",
    "    df = pd.read_csv(file)\n",
    "    df[\"prompt_embedding\"] = df[\"prompt_embedding_vec\"].apply(string_to_array)\n",
    "    df[\"paragraph_embedding\"] = df[\"paragraph_embedding_vec\"].apply(string_to_array)\n",
    "    return df.drop(columns=[\"prompt_embedding_vec\", \"paragraph_embedding_vec\"])\n",
    "\n",
    "\n",
    "pythia_df = load_dataframe(pythia_file)\n",
    "dolly_df = load_dataframe(dolly_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ThGOrwKHWgol"
   },
   "source": [
    "View the first few rows of each dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dolly_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pythia_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "132ADlmYRrCj"
   },
   "source": [
    "Compute a unique ID for each prompt by hashing the prompt text. Rows that share a prompt then share an ID, so you can group the data points that respond to the same prompt and see how the response \"unfolds\" in the latent space of the model.\n",
    "\n",
    "Here are some interesting prompt hashes to check out once you launch Phoenix.\n",
    "\n",
    "(pythia-2.8b_2023-06-01_02-51-40)\n",
    "\n",
    "- 934bc7ae9b678cb0ca42ecfc45239716 (Bernie Sanders)\n",
    "- 934bc7ae9b678cb0ca42ecfc45239716 (pollution)\n",
    "- cbee03fe7a6de75418dc69304f54b478\n",
    "- a75eae577fe7f81538237e5b7c9eeeed (AWS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hash_string(string):\n",
    "    # MD5 serves only as a stable, non-cryptographic ID: identical prompts get identical hashes.\n",
    "    return hashlib.md5(string.encode(\"utf-8\")).hexdigest()\n",
    "\n",
    "\n",
    "pythia_df[\"prompt_id\"] = pythia_df.prompt.map(hash_string)\n",
    "dolly_df[\"prompt_id\"] = dolly_df.prompt.map(hash_string)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coerce non-numeric eval scores to NaN, then count them as 0.\n",
    "pythia_df[\"evals\"] = pd.to_numeric(pythia_df[\"evals\"], errors=\"coerce\").fillna(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "S9QCDFDUTNCR"
   },
   "source": [
    "Compute the mean evaluation score for each dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dolly_df[\"evals\"].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pythia_df[\"evals\"].mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vYi8S4NTWZYp"
   },
   "source": [
    "Define a schema shared by both datasets, then launch Phoenix with one dataset, then the other."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "schema = px.Schema(\n",
    "    prompt_column_names=px.EmbeddingColumnNames(\n",
    "        raw_data_column_name=\"prompt\", vector_column_name=\"prompt_embedding\"\n",
    "    ),\n",
    "    response_column_names=px.EmbeddingColumnNames(\n",
    "        raw_data_column_name=\"response_paragraph\", vector_column_name=\"paragraph_embedding\"\n",
    "    ),\n",
    "    tag_column_names=[\n",
    "        \"prompt_category\",\n",
    "        \"conversation_id\",\n",
    "        \"response_capitalized\",\n",
    "        \"response_text\",\n",
    "        \"prompt_id\",\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pythia_ds = px.Dataset(dataframe=pythia_df, schema=schema, name=\"pythia\")\n",
    "dolly_ds = px.Dataset(dataframe=dolly_df, schema=schema, name=\"dolly\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "session = px.launch_app(pythia_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "session = px.launch_app(dolly_ds)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
